# Import necessary libraries
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
import pandas as pd
import plotly.express as px
import plotly, os, joblib
import numpy as np
import pickle
# Build a grid of radar (polar) subplots, one per model
def generatePlot(plotName, rows, cols, data, width, height, vertical_spacing, horizontal_spacing,
                 title_font, marker_size, label_size, tick_size, plot_title, line_color):
    """Build a ``rows`` x ``cols`` grid of radar (polar) charts, one per model.

    Each row of ``data`` becomes one subplot, filled left-to-right and
    top-to-bottom in the row order of ``data``. The row's index label is the
    subplot title (with ``'__'`` displayed as ``'-'``). Every column except
    ``"model"`` becomes one spoke whose radius is the score multiplied by 100.

    Args:
        plotName: Identifier of the plot; not used by this function (kept so
            existing callers keep working).
        rows, cols: Dimensions of the subplot grid.
        data: DataFrame indexed by model name with one column per metric,
            plus an optional ``"model"`` column that is skipped.
        width, height: Figure size in pixels.
        vertical_spacing, horizontal_spacing: Gaps between subplots, passed
            straight to ``make_subplots``.
        title_font: Font size of the subplot titles.
        marker_size: Size of the marker drawn at each spoke.
        label_size: Base font size of the figure.
        tick_size: Font size of the radial-axis tick labels.
        plot_title: Overall figure title.
        line_color: Line colour shared by every radar trace.

    Returns:
        plotly.graph_objects.Figure: the assembled figure.

    Raises:
        ValueError: if ``data`` has more rows than the grid has cells.
    """
    n_cells = rows * cols
    if len(data) > n_cells:
        # Fail early with a clear message instead of an out-of-range
        # (row, col) error from plotly halfway through add_trace.
        raise ValueError(
            f"{len(data)} models do not fit in a {rows}x{cols} subplot grid ({n_cells} cells)"
        )

    # Every cell of the grid holds a polar (radar) chart.
    specs = [[{"type": "polar"} for _ in range(cols)] for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, vertical_spacing=vertical_spacing,
                        subplot_titles=[str(label).replace('__', '-') for label in data.index],
                        horizontal_spacing=horizontal_spacing, specs=specs)

    for position, model in enumerate(data.index):
        # make_subplots uses 1-based row/col indices.
        row, col = divmod(position, cols)
        model_score = data.loc[model]
        # The "model" column holds the label, not a score, so it is not a spoke.
        metrics = [score for score in model_score.index if score != "model"]
        values = [model_score.loc[score] * 100 for score in metrics]
        fig.add_trace(go.Scatterpolar(r=values, name=model, dtheta=20,
                                      theta=metrics, fill='toself',
                                      line_color=line_color),
                      row=row + 1, col=col + 1)

    # Subplot titles are stored by make_subplots as layout annotations.
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(size=title_font, family="Arial", color='black')
        annotation['borderpad'] = 5

    fig.update_layout(width=width, height=height, font_size=label_size, template="plotly_white",
                      font_family="Arial", showlegend=False,
                      margin=dict(t=30, b=15, r=40, l=40),
                      title={'text': plot_title, 'y': 0.995, 'x': 0.5,
                             'xanchor': 'center', 'yanchor': 'top',
                             "font_family": "Arial", "font_size": 10})

    # Radial axis is fixed to 30-100 (percent), so scores below 0.30 fall
    # outside the displayed range.
    fig.update_polars(
        radialaxis=dict(visible=True, nticks=7, range=[30, 100],
                        tickfont=dict(size=tick_size)),
        angularaxis=dict(showticklabels=True, ticks='', linewidth=0.2,
                         showline=True, linecolor='black'),
    )
    fig.update_traces(marker=dict(size=marker_size, line_color="black",
                                  color=px.colors.sequential.Viridis),
                      selector=dict(type='scatterpolar'))
    return fig
# Build spider (radar) plots of the classification metrics for every model
def getSpyderPlot(results, cvName, plot_title, line_color):
    """Tabulate per-model metrics and draw them as a grid of radar charts.

    ``results`` is read at the keys "Acc", "Bal_acc", "F1", "recall",
    "precision", "average_precision", "roc_auc" and "model". Presumably each
    value is a per-model sequence of equal length (verify against the pickled
    results). The table is indexed by the "model" values and sorted by
    ascending F1, so the lowest-F1 model fills the first cell. It is then
    plotted on a 2x4 grid of 850x450 px via ``generatePlot``.

    Returns:
        The figure produced by ``generatePlot``.
    """
    # (column shown in the plot, key in ``results``). The order here is the
    # order of the spokes on every radar chart.
    column_sources = (
        ('Accuracy', "Acc"),
        ('Balanced Acc', "Bal_acc"),
        ('F1', "F1"),
        ('Recall', "recall"),
        ('Precision', "precision"),
        ('Avg precision', "average_precision"),
        ('roc_auc', "roc_auc"),
        ("model", "model"),
    )
    table = pd.DataFrame(
        {column: results[key] for column, key in column_sources},
        index=results["model"],
    )
    ranked = table.sort_values('F1')

    return generatePlot(
        plotName=cvName,
        rows=2,
        cols=4,
        data=ranked,
        width=850,
        height=450,
        vertical_spacing=0.08,
        horizontal_spacing=0.09,
        title_font=10,
        marker_size=4,
        label_size=9,
        tick_size=9,
        plot_title=plot_title,
        line_color=line_color,
    )
# Specify the path to the results folder downloaded from GitHub, e.g.
# resultsFolder = '/path/to/results/folder'
resultsFolder = '/Users/akshay/Desktop/prof_katia_plots/pain-paper/MLcps-paper/MLcps/generateManuPlots/results'

# Result sets that can be plotted: name -> (pickle file, plot name, line colour).
RESULT_SETS = {
    "whole": ("results_whole.pickle", "result_whole", "#B9E4E8"),
    "test": ("results_test.pickle", "result_test", "#8CBFAA"),
}

# Datasets to plot: (sub-folder of resultsFolder, plot title, result sets).
DATASETS = [
    ("CLL", "CLL", ("whole",)),
    ("cervical", "Cervical", ("whole",)),
    ("TCGA-BRCA_miRNA", "TCGA-BRCA-miRNA", ("whole", "test")),
    ("TCGA-BRCA_mRNA", "TCGA-BRCA-mRNA", ("whole", "test")),
]


def load_results(path):
    """Return the object stored in the pickle file at *path*.

    The file is opened in a ``with`` block, so the handle is always closed.
    NOTE: unpickling can execute arbitrary code; only load result files from
    a source you trust.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


if __name__ == "__main__":
    for folder, title, set_names in DATASETS:
        # All plots for a dataset are saved in that dataset's results folder.
        dataset_dir = os.path.join(resultsFolder, folder)
        for set_name in set_names:
            pickle_name, cv_name, color = RESULT_SETS[set_name]
            # Each plot is drawn from its own pickle, so the "result_test"
            # plots really show results_test.
            results = load_results(os.path.join(dataset_dir, pickle_name))
            fig = getSpyderPlot(results, cv_name, title, color)
            # HTML output needs no extra dependency. For a static PDF/PNG use
            # fig.write_image(...), which requires the kaleido package.
            fig.write_html(os.path.join(dataset_dir, cv_name + ".html"))